"""
Color Profile System for Cinespinner Pro

Modular color extraction algorithms that can be applied to video frame data.
Each profile extracts a single representative color from a collection of pixels.
"""

from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional, Tuple, Union

import numpy as np

# Try to import sklearn for k-means, fall back to simple implementation if not available
try:
    from sklearn.cluster import MiniBatchKMeans
    HAS_SKLEARN = True
except ImportError:
    HAS_SKLEARN = False


class ProfileType(Enum):
    """
    Available color profile types.

    Each member's value is the human-readable name returned (in declaration
    order) by ``get_all_profile_names()`` for the UI.
    """
    AVERAGE = "Average"
    MODE = "Mode"
    VIBRANT = "Vibrant"
    DOMINANT = "Dominant"
    MEDIAN = "Median"
    HIGHLIGHT = "Highlight"
    WEIGHTED = "Weighted"


def rgb_to_hsv(rgb: np.ndarray) -> np.ndarray:
    """
    Convert RGB array to HSV.

    Args:
        rgb: Array of shape (N, 3) with RGB values 0-255

    Returns:
        float32 array of shape (N, 3) with H (0-360), S (0-1), V (0-1).
        Achromatic pixels (R == G == B) get H = 0 and S = 0.
    """
    rgb_normalized = rgb.astype(np.float32) / 255.0

    r, g, b = rgb_normalized[:, 0], rgb_normalized[:, 1], rgb_normalized[:, 2]

    max_c = np.maximum(np.maximum(r, g), b)
    min_c = np.minimum(np.minimum(r, g), b)
    delta = max_c - min_c

    # Hue is only defined for chromatic pixels (delta != 0); others stay 0.
    # Where two channels tie for the maximum the later assignment wins, but the
    # formulas agree on those boundaries so the result is unaffected.
    h = np.zeros_like(max_c)
    chromatic = delta != 0

    # When max == r
    mask = (max_c == r) & chromatic
    h[mask] = 60 * (((g[mask] - b[mask]) / delta[mask]) % 6)

    # When max == g
    mask = (max_c == g) & chromatic
    h[mask] = 60 * (((b[mask] - r[mask]) / delta[mask]) + 2)

    # When max == b
    mask = (max_c == b) & chromatic
    h[mask] = 60 * (((r[mask] - g[mask]) / delta[mask]) + 4)

    # Saturation: divide only where max_c != 0. Pure-black pixels keep S = 0
    # without ever evaluating 0/0, which would emit a RuntimeWarning.
    s = np.zeros_like(max_c)
    np.divide(delta, max_c, out=s, where=max_c != 0)

    # Value
    v = max_c

    return np.stack([h, s, v], axis=1)


def hsv_to_rgb(hsv: np.ndarray) -> np.ndarray:
    """
    Convert a single HSV color back to RGB.

    Args:
        hsv: Three values: H in degrees, S (0-1), V (0-1). H is wrapped
            modulo 360, so e.g. 420 and -300 both mean 60.

    Returns:
        numpy array [r, g, b] of floats in 0-255 (not rounded)
    """
    h, s, v = hsv[0], hsv[1], hsv[2]

    # Wrap the hue into [0, 360) so the sextant selection below is valid for
    # any angle; without this, 420 or -60 would pick the wrong branch.
    h = h % 360

    c = v * s
    x = c * (1 - abs((h / 60) % 2 - 1))
    m = v - c

    if h < 60:
        r, g, b = c, x, 0
    elif h < 120:
        r, g, b = x, c, 0
    elif h < 180:
        r, g, b = 0, c, x
    elif h < 240:
        r, g, b = 0, x, c
    elif h < 300:
        r, g, b = x, 0, c
    else:
        r, g, b = c, 0, x

    return np.array([(r + m) * 255, (g + m) * 255, (b + m) * 255])


def rgb_to_lab(rgb: np.ndarray) -> np.ndarray:
    """
    Convert sRGB pixels to CIELAB (D65 reference white).

    Args:
        rgb: Array of shape (N, 3) with RGB values 0-255

    Returns:
        Array of shape (N, 3) holding L (0-100) and a, b (roughly -128..127)
    """
    srgb = rgb.astype(np.float32) / 255.0

    # Undo the sRGB transfer curve to get linear-light intensities.
    linear = np.where(
        srgb > 0.04045,
        ((srgb + 0.055) / 1.055) ** 2.4,
        srgb / 12.92,
    )

    # Linear sRGB -> CIE XYZ, then scale each axis by the D65 white point.
    srgb_to_xyz = np.array([
        [0.4124564, 0.3575761, 0.1804375],
        [0.2126729, 0.7151522, 0.0721750],
        [0.0193339, 0.1191920, 0.9503041],
    ])
    d65_white = np.array([0.95047, 1.00000, 1.08883])
    xyz = np.dot(linear, srgb_to_xyz.T) / d65_white

    # CIE f(t): cube root above the threshold, linear segment below it.
    eps, kap = 0.008856, 903.3
    f = np.where(xyz > eps, np.cbrt(xyz), (kap * xyz + 16) / 116)
    fx, fy, fz = f[:, 0], f[:, 1], f[:, 2]

    lightness = 116 * fy - 16
    green_red = 500 * (fx - fy)
    blue_yellow = 200 * (fy - fz)
    return np.stack([lightness, green_red, blue_yellow], axis=1)


def lab_to_rgb(lab: np.ndarray) -> np.ndarray:
    """
    Convert a single CIELAB color (D65 reference white) back to sRGB.

    Args:
        lab: Three values L, a, b

    Returns:
        numpy array [r, g, b] of floats clipped to 0-255 (not rounded).
        Out-of-gamut colors are clipped channel-wise.
    """
    L, a, b = lab[0], lab[1], lab[2]

    # Convert Lab to XYZ (inverse of the CIE f(t) function)
    fy = (L + 16) / 116
    fx = a / 500 + fy
    fz = fy - b / 200

    epsilon = 0.008856
    kappa = 903.3

    xr = fx ** 3 if fx ** 3 > epsilon else (116 * fx - 16) / kappa
    yr = ((L + 16) / 116) ** 3 if L > kappa * epsilon else L / kappa
    zr = fz ** 3 if fz ** 3 > epsilon else (116 * fz - 16) / kappa

    # Apply D65 white point
    x = xr * 0.95047
    y = yr * 1.00000
    z = zr * 1.08883

    # Convert XYZ to linear RGB
    r = x * 3.2404542 + y * -1.5371385 + z * -0.4985314
    g = x * -0.9692660 + y * 1.8760108 + z * 0.0415560
    b = x * 0.0556434 + y * -0.2040259 + z * 1.0572252

    # Out-of-gamut colors can produce negative linear values. Clamp them to 0
    # before the fractional power so it never sees a negative base (NaN plus a
    # RuntimeWarning). Such channels ended up clipped to 0 anyway, so the final
    # output is unchanged.
    rgb_linear = np.maximum(np.array([r, g, b]), 0.0)

    # Apply gamma correction (linear to sRGB)
    mask = rgb_linear > 0.0031308
    rgb_srgb = np.where(mask, 1.055 * (rgb_linear ** (1 / 2.4)) - 0.055, 12.92 * rgb_linear)

    # Clamp and scale to 0-255
    return np.clip(rgb_srgb * 255, 0, 255)


def filter_dark_pixels(pixels: np.ndarray, threshold: int = 15) -> np.ndarray:
    """
    Drop near-black pixels such as letterbox bars and deep shadows.

    A pixel survives when the mean of its three channels is strictly greater
    than ``threshold``.

    Args:
        pixels: RGB pixel array (N, 3)
        threshold: Exclusive brightness cutoff (0-255)

    Returns:
        The surviving pixels, or ``pixels`` unchanged when fewer than 10
        would survive (so a mostly-black frame still has data to analyse).
    """
    level = np.mean(pixels, axis=1)
    survivors = pixels[level > threshold]
    return survivors if len(survivors) >= 10 else pixels


def filter_extreme_pixels(pixels: np.ndarray, low: int = 15, high: int = 240) -> np.ndarray:
    """
    Drop pixels that are nearly black or nearly white.

    Brightness is the mean of a pixel's three channels; only pixels strictly
    between ``low`` and ``high`` are kept.

    Args:
        pixels: RGB pixel array (N, 3)
        low: Exclusive lower brightness bound
        high: Exclusive upper brightness bound

    Returns:
        The in-range pixels, or ``pixels`` unchanged when fewer than 10
        survive the filter.
    """
    level = np.mean(pixels, axis=1)
    kept = pixels[(level > low) & (level < high)]
    if len(kept) < 10:
        return pixels
    return kept


def boost_saturation(rgb: Tuple[int, int, int], factor: float = 1.3) -> Tuple[int, int, int]:
    """
    Scale the saturation of an RGB color.

    The color's HSL saturation is multiplied by ``factor`` (capped at 1.0),
    and each channel is moved away from (or toward) the color's gray level,
    the mean of its channels, by the same ratio.

    Args:
        rgb: RGB tuple (0-255)
        factor: Saturation multiplier (1.0 = no change, 0.0 = fully gray).
            Negative values are treated as 0.

    Returns:
        RGB tuple of ints (0-255). Channels pushed outside the valid range are
        clipped; values are rounded to the nearest integer (not truncated, so
        factor 1.0 returns the input color).
    """
    r, g, b = rgb[0] / 255.0, rgb[1] / 255.0, rgb[2] / 255.0

    max_c = max(r, g, b)
    min_c = min(r, g, b)

    if max_c == min_c:
        # Achromatic: no saturation to scale.
        return (int(round(rgb[0])), int(round(rgb[1])), int(round(rgb[2])))

    # HSL lightness and saturation
    lightness = (max_c + min_c) / 2
    if lightness <= 0.5:
        s = (max_c - min_c) / (max_c + min_c)
    else:
        s = (max_c - min_c) / (2 - max_c - min_c)

    # s > 0 here because max_c != min_c, so the ratio below is well defined.
    new_s = min(1.0, max(0.0, s * factor))
    scale = new_s / s

    # Interpolate each channel towards/away from gray
    gray = (r + g + b) / 3
    return tuple(
        int(round(float(np.clip((gray + (channel - gray) * scale) * 255, 0, 255))))
        for channel in (r, g, b)
    )


class ColorProfile(ABC):
    """
    Interface shared by every color extraction strategy.

    Subclasses provide a human-readable ``name`` and ``description`` as class
    attributes and implement :meth:`extract`.
    """

    name: str = "Base"
    description: str = "Base color profile"

    @abstractmethod
    def extract(self, pixels: np.ndarray) -> Tuple[int, int, int]:
        """
        Reduce a batch of pixels to a single representative color.

        Args:
            pixels: numpy array of shape (N, 3) with RGB values 0-255

        Returns:
            RGB tuple (r, g, b) with values 0-255
        """
        ...


class AverageProfile(ColorProfile):
    """
    Improved Average Profile

    Averages pixels in the perceptual CIELAB color space (optionally after
    dropping near-black pixels), converts the mean back to RGB, and optionally
    boosts saturation to counteract muddy blending.
    """

    name = "Average"
    description = "Smooth blend with saturation boost, dark pixels filtered"

    def __init__(self, filter_dark: bool = True, saturation_boost: float = 1.25):
        # filter_dark: drop near-black pixels (mean channel value <= 20) first.
        self.filter_dark = filter_dark
        # saturation_boost: factor for boost_saturation; only applied when > 1.0.
        self.saturation_boost = saturation_boost

    def extract(self, pixels: np.ndarray) -> Tuple[int, int, int]:
        """
        Return the LAB-space mean of ``pixels`` as an RGB int tuple.

        Args:
            pixels: numpy array of shape (N, 3) with RGB values 0-255

        Returns:
            RGB tuple (r, g, b); (0, 0, 0) for empty input.
        """
        if len(pixels) == 0:
            return (0, 0, 0)

        # Filter dark pixels
        if self.filter_dark:
            pixels = filter_dark_pixels(pixels, threshold=20)

        if len(pixels) == 0:
            return (0, 0, 0)

        # Average in LAB space for perceptually even blending
        avg_lab = np.mean(rgb_to_lab(pixels), axis=0)

        # Round rather than truncate: the float LAB round trip can land a hair
        # below the true integer (e.g. 199.9999...), which int() would turn
        # into 199 and bias every channel downward.
        avg_rgb = np.rint(lab_to_rgb(avg_lab))
        result = (int(avg_rgb[0]), int(avg_rgb[1]), int(avg_rgb[2]))

        # Boost saturation
        if self.saturation_boost > 1.0:
            result = boost_saturation(result, self.saturation_boost)

        return result


class ModeProfile(ColorProfile):
    """
    Improved Mode Profile

    Quantizes colors into coarse bins (optionally after dropping near-black
    and near-white pixels), finds the most populated bin, and returns the
    mean color of the pixels that fell into it.
    """

    name = "Mode"
    description = "Most common color, with extremes filtered"

    def __init__(self, filter_extremes: bool = True, quantization_bits: int = 5):
        self.filter_extremes = filter_extremes
        # Bits kept per channel: 5 = 32 levels, 4 = 16 levels (expected 1-8).
        self.quantization_bits = quantization_bits

    def extract(self, pixels: np.ndarray) -> Tuple[int, int, int]:
        """
        Return the representative color of the most frequent quantized bin.

        Args:
            pixels: numpy array of shape (N, 3) with RGB values 0-255
                (integer or float dtype)

        Returns:
            RGB tuple (r, g, b); (0, 0, 0) for empty input.
        """
        if len(pixels) == 0:
            return (0, 0, 0)

        # Filter extreme pixels
        if self.filter_extremes:
            pixels = filter_extreme_pixels(pixels, low=20, high=235)

        if len(pixels) == 0:
            return (0, 0, 0)

        # Bit shifts need an integer dtype; clip + cast so float frames work too.
        channels = np.clip(pixels, 0, 255).astype(np.int32)
        shift = 8 - self.quantization_bits
        quantized = (channels >> shift) << shift

        # Pack each quantized RGB triple into one int so np.unique can count bins
        packed = (quantized[:, 0] << 16) | (quantized[:, 1] << 8) | quantized[:, 2]

        _, bin_of_pixel, counts = np.unique(packed, return_inverse=True, return_counts=True)
        winner = np.argmax(counts)

        # Report the mean of the winning bin's pixels instead of the bin's lower
        # corner, which would darken every channel by up to 2**shift - 1.
        members = pixels[bin_of_pixel.reshape(-1) == winner]
        mean = np.rint(np.mean(members, axis=0))

        return (int(mean[0]), int(mean[1]), int(mean[2]))


class VibrantProfile(ColorProfile):
    """
    Vibrant Profile

    Averages the most vibrant pixels (highest HSV saturation * value) among
    those passing the saturation/brightness thresholds.
    Great for bold, eye-catching visualizations.
    """

    name = "Vibrant"
    description = "Most saturated color for bold results"

    def __init__(self, min_saturation: float = 0.25, min_value: float = 0.15, max_value: float = 0.95,
                 top_fraction: float = 0.1):
        self.min_saturation = min_saturation
        self.min_value = min_value
        self.max_value = max_value
        # Fraction of qualifying pixels (by vibrancy score) that get averaged.
        self.top_fraction = top_fraction

    def extract(self, pixels: np.ndarray) -> Tuple[int, int, int]:
        """
        Return the rounded mean of the top ``top_fraction`` most vibrant pixels.

        Falls back to an unboosted ``AverageProfile`` when no pixel passes the
        saturation/value thresholds.

        Args:
            pixels: numpy array of shape (N, 3) with RGB values 0-255

        Returns:
            RGB tuple (r, g, b); (0, 0, 0) for empty input.
        """
        if len(pixels) == 0:
            return (0, 0, 0)

        hsv = rgb_to_hsv(pixels)
        saturation, value = hsv[:, 1], hsv[:, 2]

        # Filter by saturation and value thresholds
        mask = (saturation >= self.min_saturation) & \
               (value >= self.min_value) & \
               (value <= self.max_value)

        if not np.any(mask):
            # Fallback to average if no vibrant pixels
            return AverageProfile(filter_dark=True, saturation_boost=1.0).extract(pixels)

        candidates = pixels[mask]
        scores = saturation[mask] * value[mask]  # vibrancy score

        # Keep at least one pixel and never more than exist
        top_k = min(len(scores), max(1, int(len(scores) * self.top_fraction)))
        top_pixels = candidates[np.argpartition(scores, -top_k)[-top_k:]]

        # Round (not truncate) so the result is not biased downward
        avg_color = np.rint(np.mean(top_pixels, axis=0))

        return (int(avg_color[0]), int(avg_color[1]), int(avg_color[2]))


class DominantProfile(ColorProfile):
    """
    Dominant Profile (K-Means Clustering)

    Clusters the (dark-filtered) pixels with k-means and returns the centroid
    of the best-scoring cluster. The score is the cluster size, optionally
    multiplied by (1 + centroid saturation) so more colorful clusters are favored.
    """

    name = "Dominant"
    description = "K-means clustered dominant color"

    def __init__(self, n_clusters: int = 5, prefer_saturated: bool = True,
                 random_state: Optional[int] = None):
        self.n_clusters = n_clusters
        self.prefer_saturated = prefer_saturated
        # None keeps the global NumPy RNG / unseeded scikit-learn behaviour; an
        # int makes subsampling and clustering reproducible for the same input.
        self.random_state = random_state

    @staticmethod
    def _nearest_centroid(data: np.ndarray, centroids: np.ndarray) -> np.ndarray:
        """Index of the closest centroid (squared Euclidean distance) per pixel."""
        sq_dist = np.sum((data[:, np.newaxis] - centroids) ** 2, axis=2)
        return np.argmin(sq_dist, axis=1)

    def _simple_kmeans(self, pixels: np.ndarray, n_clusters: int, max_iter: int = 10,
                       rng=None) -> Tuple[np.ndarray, np.ndarray]:
        """
        Minimal Lloyd's k-means used when scikit-learn is not available.

        Args:
            pixels: (N, 3) pixel array, N >= 1
            n_clusters: Requested number of clusters
            max_iter: Maximum refinement iterations
            rng: Object with a ``choice`` method; defaults to ``np.random``

        Returns:
            (centroids, labels): centroids of shape (k, 3) with
            k = min(n_clusters, N), and per-pixel labels for those centroids.
        """
        rng = np.random if rng is None else rng
        data = pixels.astype(np.float32)
        k = min(n_clusters, len(data))

        # Initialize centroids from k distinct random pixels
        centroids = data[rng.choice(len(data), k, replace=False)]

        for _ in range(max_iter):
            labels = self._nearest_centroid(data, centroids)
            # Empty clusters keep their previous centroid
            new_centroids = np.array([
                data[labels == i].mean(axis=0) if np.any(labels == i) else centroids[i]
                for i in range(k)
            ])
            converged = np.allclose(centroids, new_centroids)
            centroids = new_centroids
            if converged:
                break

        # Label against the centroids actually returned, so cluster sizes match.
        return centroids, self._nearest_centroid(data, centroids)

    def extract(self, pixels: np.ndarray) -> Tuple[int, int, int]:
        """
        Return the rounded centroid of the dominant color cluster.

        Args:
            pixels: numpy array of shape (N, 3) with RGB values 0-255

        Returns:
            RGB tuple (r, g, b) of Python ints; (0, 0, 0) for empty input.
        """
        if len(pixels) == 0:
            return (0, 0, 0)

        # Filter dark pixels first
        pixels = filter_dark_pixels(pixels, threshold=15)

        if len(pixels) < self.n_clusters:
            # Too few samples to form n_clusters groups: use the plain mean.
            mean = np.rint(np.mean(pixels, axis=0))
            return (int(mean[0]), int(mean[1]), int(mean[2]))

        rng = np.random if self.random_state is None else np.random.RandomState(self.random_state)

        # Subsample for performance if too many pixels
        if len(pixels) > 1000:
            sample_pixels = pixels[rng.choice(len(pixels), 1000, replace=False)]
        else:
            sample_pixels = pixels

        # Run k-means
        if HAS_SKLEARN:
            kmeans = MiniBatchKMeans(n_clusters=self.n_clusters, n_init=1, max_iter=10,
                                     random_state=self.random_state)
            kmeans.fit(sample_pixels)
            centroids = kmeans.cluster_centers_
            labels = kmeans.predict(sample_pixels)
        else:
            centroids, labels = self._simple_kmeans(sample_pixels, self.n_clusters, rng=rng)

        # Cluster sizes, one entry per centroid (empty clusters count 0)
        scores = np.bincount(labels, minlength=len(centroids)).astype(np.float64)

        if self.prefer_saturated:
            # Score = size * (1 + centroid saturation)
            hsv_centroids = rgb_to_hsv(np.clip(centroids, 0, 255).astype(np.uint8).reshape(-1, 3))
            scores = scores * (1 + hsv_centroids[:, 1])

        best_color = np.rint(centroids[np.argmax(scores)])

        return (int(best_color[0]), int(best_color[1]), int(best_color[2]))


class MedianProfile(ColorProfile):
    """
    Median Profile

    Takes the per-channel median of the pixels (optionally after dropping
    near-black ones). Less sensitive to outliers than a plain mean; note the
    result need not be a color that actually occurs in the input.
    """

    name = "Median"
    description = "Median color, resistant to outliers"

    def __init__(self, filter_dark: bool = True):
        self.filter_dark = filter_dark

    def extract(self, pixels: np.ndarray) -> Tuple[int, int, int]:
        """
        Return the per-channel median (truncated to int) as an RGB tuple.

        Args:
            pixels: numpy array of shape (N, 3) with RGB values 0-255

        Returns:
            RGB tuple (r, g, b); (0, 0, 0) for empty input.
        """
        if len(pixels) == 0:
            return (0, 0, 0)

        candidates = filter_dark_pixels(pixels, threshold=15) if self.filter_dark else pixels
        if len(candidates) == 0:
            return (0, 0, 0)

        return tuple(int(np.median(candidates[:, channel])) for channel in range(3))


class HighlightProfile(ColorProfile):
    """
    Highlight Profile

    Averages the pixels whose brightness (mean of R, G, B) lies between the
    ``low_percentile`` and ``high_percentile`` of the frame (75th to 95th by
    default), excluding near-white pixels. Great for capturing the "glow" of
    scenes, neon aesthetics.
    """

    name = "Highlight"
    description = "Bright pixels emphasized, captures scene glow"

    def __init__(self, low_percentile: float = 75, high_percentile: float = 95,
                 max_brightness: float = 245):
        self.low_percentile = low_percentile
        self.high_percentile = high_percentile
        # Pixels at or above this brightness are treated as blown-out white.
        self.max_brightness = max_brightness

    def extract(self, pixels: np.ndarray) -> Tuple[int, int, int]:
        """
        Return the rounded mean color of the highlight band.

        Args:
            pixels: numpy array of shape (N, 3) with RGB values 0-255

        Returns:
            RGB tuple (r, g, b); (0, 0, 0) for empty input.
        """
        if len(pixels) == 0:
            return (0, 0, 0)

        brightness = np.mean(pixels, axis=1)

        low_thresh, high_thresh = np.percentile(
            brightness, [self.low_percentile, self.high_percentile]
        )

        # Filter to highlight range (avoid pure white)
        mask = (brightness >= low_thresh) & \
               (brightness <= high_thresh) & \
               (brightness < self.max_brightness)
        highlight_pixels = pixels[mask]

        if len(highlight_pixels) == 0:
            # Band is empty: every candidate was near-white, or there are so few
            # pixels that the interpolated percentiles bracket none of them.
            return AverageProfile().extract(pixels)

        # Round (not truncate) so the result is not biased downward
        avg = np.rint(np.mean(highlight_pixels, axis=0))

        return (int(avg[0]), int(avg[1]), int(avg[2]))


class WeightedProfile(ColorProfile):
    """
    Weighted Saturation Profile

    Averages pixels weighted by their HSV saturation (plus a small floor), so
    more saturated pixels contribute more to the final color.
    """

    name = "Weighted"
    description = "Saturation-weighted average for colorful results"

    def __init__(self, filter_dark: bool = True):
        self.filter_dark = filter_dark

    def extract(self, pixels: np.ndarray) -> Tuple[int, int, int]:
        """
        Return the rounded saturation-weighted mean color.

        Args:
            pixels: numpy array of shape (N, 3) with RGB values 0-255

        Returns:
            RGB tuple (r, g, b); (0, 0, 0) for empty input.
        """
        if len(pixels) == 0:
            return (0, 0, 0)

        if self.filter_dark:
            pixels = filter_dark_pixels(pixels, threshold=15)

        if len(pixels) == 0:
            return (0, 0, 0)

        # The 0.01 floor keeps an all-gray frame (every saturation 0) from
        # having a zero total weight.
        weights = rgb_to_hsv(pixels)[:, 1] + 0.01

        # Accumulate in float64 and round: a float32 weighted sum truncated
        # with int() could turn a uniform 200 channel into 199.
        avg = np.rint(np.average(np.asarray(pixels, dtype=np.float64), axis=0, weights=weights))

        return (int(avg[0]), int(avg[1]), int(avg[2]))


# Profile registry: maps each ProfileType to the ColorProfile subclass that
# implements it. Values are classes, not instances; callers such as
# get_profile() instantiate them with default constructor arguments.
PROFILES = {
    ProfileType.AVERAGE: AverageProfile,
    ProfileType.MODE: ModeProfile,
    ProfileType.VIBRANT: VibrantProfile,
    ProfileType.DOMINANT: DominantProfile,
    ProfileType.MEDIAN: MedianProfile,
    ProfileType.HIGHLIGHT: HighlightProfile,
    ProfileType.WEIGHTED: WeightedProfile,
}


def get_profile(profile_type: Union[ProfileType, str]) -> ColorProfile:
    """
    Create a default-configured instance of the requested profile.

    Args:
        profile_type: A ``ProfileType`` member, or its string value as returned
            by ``get_all_profile_names()`` (e.g. ``"Average"``).

    Returns:
        A new ``ColorProfile`` built with its default constructor arguments.

    Raises:
        KeyError: If ``profile_type`` does not name a known profile.
    """
    try:
        # ProfileType(member) returns the member itself; ProfileType("Average")
        # looks the member up by value.
        key = ProfileType(profile_type)
    except ValueError as err:
        raise KeyError(profile_type) from err
    return PROFILES[key]()


def get_all_profile_names() -> list:
    """
    List every profile's display name for the UI.

    Returns:
        A new list of ``ProfileType`` values (e.g. ``"Average"``), in the
        order the members are declared.
    """
    return [member.value for member in ProfileType]


def get_profile_description(profile_type: Union[ProfileType, str]) -> str:
    """
    Get the description for a profile type.

    Args:
        profile_type: A ``ProfileType`` member, or its string value as returned
            by ``get_all_profile_names()``.

    Returns:
        The profile class's ``description`` string.

    Raises:
        KeyError: If ``profile_type`` does not name a known profile.
    """
    try:
        key = ProfileType(profile_type)
    except ValueError as err:
        raise KeyError(profile_type) from err
    # ``description`` is a class attribute, so no instance needs to be built.
    return PROFILES[key].description
